"""
Random variables for sampling camera poses
Author: Jeff Mahler
"""
import copy
import logging

import numpy as np
import scipy.stats as ss

from autolab_core import Point, RigidTransform, RandomVariable
from autolab_core.utils import sph2cart, cart2sph
from perception import CameraIntrinsics, BinaryImage, ColorImage, DepthImage, ObjectRender, RenderMode

from mesh_renderer import VirtualCamera, SceneObject

class CameraSample(object):
    """ Record of one sampled camera: its pose, its intrinsics, and the raw
    parameter draws that produced them.

    Attributes
    ----------
    object_to_camera_pose
        sampled object-to-camera pose; must provide ``inverse()``
    camera_intr
        camera intrinsics; must provide a ``frame`` attribute
    radius, elev, az, roll
        sampled viewsphere parameters
    tx, ty
        sampled planar translation (default 0)
    focal, cx, cy
        sampled focal length and optical center (default 0)
    """
    def __init__(self, object_to_camera_pose, camera_intr,
                 radius, elev, az, roll, tx=0, ty=0, focal=0,
                 cx=0, cy=0):
        # pose and intrinsics
        self.object_to_camera_pose, self.camera_intr = object_to_camera_pose, camera_intr

        # viewsphere coordinates
        self.radius, self.elev, self.az, self.roll = radius, elev, az, roll

        # translation in the plane
        self.tx, self.ty = tx, ty

        # raw intrinsics draws
        self.focal, self.cx, self.cy = focal, cx, cy

    @property
    def T_camera_world(self):
        """ Inverse of the object-to-camera pose, re-labeled through
        ``as_frames`` with the intrinsics' frame and ``'world'``. """
        camera_to_object = self.object_to_camera_pose.inverse()
        camera_frame = self.camera_intr.frame
        return camera_to_object.as_frames(camera_frame, 'world')

class RenderSample(object):
    """ Pairs a set of rendered images with the camera sample that produced them.

    Attributes
    ----------
    renders
        rendered images, stored exactly as passed in
    camera
        the camera sample used for rendering, stored exactly as passed in
    """
    def __init__(self, renders, camera):
        self.renders, self.camera = renders, camera

class UniformViewsphereRandomVariable(RandomVariable):
    """
    Uniform distribution over a bounded region of a viewing sphere.
    """
    def __init__(self, min_radius, max_radius,
                 min_elev, max_elev,
                 min_az=0, max_az=2*np.pi,
                 min_roll=0, max_roll=2*np.pi,
                 num_prealloc_samples=1,
                 frame='camera'):
        """Initialize a uniform random variable over viewsphere camera poses.

        Parameters
        ----------
        min_radius : float
            Minimum radius for viewing sphere.
        max_radius : float
            Maximum radius for viewing sphere.
        min_elev : float
            Minimum elevation (angle from z-axis) for camera position.
        max_elev : float
            Maximum elevation for camera position.
        min_az : float
            Minimum azimuth (angle from x-axis) for camera position.
        max_az : float
            Maximum azimuth for camera position.
        min_roll : float
            Minimum roll (rotation of camera about axis generated by azimuth and
            elevation) for camera.
        max_roll : float
            Maximum roll for camera.
        num_prealloc_samples : int
            Number of preallocated samples.
        frame : :obj:`str`
            Name of the camera frame used as the `from_frame` of the sampled
            transforms.

        Notes
        -----
        Every angular bound is multiplied by ``np.pi`` before use, i.e. the
        bounds are interpreted as multiples of pi radians.
        NOTE(review): the angular defaults are ``2*np.pi``, which after this
        scaling spans ``2*np.pi**2`` radians -- confirm whether the defaults
        or the scaling reflect the intended units.
        """
        # read params
        # frame is read by object_to_camera_pose, so it must be set before
        # the base-class init runs (which may pre-draw samples)
        self.frame = frame
        self.min_radius = min_radius
        self.max_radius = max_radius
        self.min_az = min_az * np.pi
        self.max_az = max_az * np.pi
        self.min_elev = min_elev * np.pi
        self.max_elev = max_elev * np.pi
        self.min_roll = min_roll * np.pi
        self.max_roll = max_roll * np.pi
        self.num_prealloc_samples = num_prealloc_samples

        # setup random variables (scipy uniform is parameterized by loc and width)
        self.rad_rv = ss.uniform(loc=self.min_radius, scale=self.max_radius-self.min_radius)
        self.elev_rv = ss.uniform(loc=self.min_elev, scale=self.max_elev-self.min_elev)
        self.az_rv = ss.uniform(loc=self.min_az, scale=self.max_az-self.min_az)
        self.roll_rv = ss.uniform(loc=self.min_roll, scale=self.max_roll-self.min_roll)

        RandomVariable.__init__(self, self.num_prealloc_samples)

    def object_to_camera_pose(self, radius, elev, az, roll):
        """ Convert spherical coords to an object-camera pose.

        Parameters
        ----------
        radius : float
            distance of the camera center from the object origin
        elev : float
            elevation passed to `sph2cart`
        az : float
            azimuth passed to `sph2cart`
        roll : float
            rotation, in radians, applied about the camera z-axis

        Returns
        -------
        :obj:`RigidTransform`
            inverse of the transform built from `self.frame` to 'obj'
        """
        # generate camera center from spherical coords
        camera_center_obj = np.array([sph2cart(radius, az, elev)]).squeeze()
        # the camera z-axis points from the camera center back at the origin
        camera_z_obj = -camera_center_obj / np.linalg.norm(camera_center_obj)

        # find the canonical camera x and y axes
        camera_x_par_obj = np.array([camera_z_obj[1], -camera_z_obj[0], 0])
        if np.linalg.norm(camera_x_par_obj) == 0:
            # z-axis is parallel to the object z-axis: pick x arbitrarily
            camera_x_par_obj = np.array([1, 0, 0])
        camera_x_par_obj = camera_x_par_obj / np.linalg.norm(camera_x_par_obj)
        camera_y_par_obj = np.cross(camera_z_obj, camera_x_par_obj)
        camera_y_par_obj = camera_y_par_obj / np.linalg.norm(camera_y_par_obj)
        if camera_y_par_obj[2] > 0:
            # flip x so that the y-axis has a non-positive z component
            camera_x_par_obj = -camera_x_par_obj
            camera_y_par_obj = np.cross(camera_z_obj, camera_x_par_obj)
            camera_y_par_obj = camera_y_par_obj / np.linalg.norm(camera_y_par_obj)

        # rotate by the roll
        R_obj_camera_par = np.c_[camera_x_par_obj, camera_y_par_obj, camera_z_obj]
        R_camera_par_camera = np.array([[np.cos(roll), -np.sin(roll), 0],
                                        [np.sin(roll), np.cos(roll), 0],
                                        [0, 0, 1]])
        R_obj_camera = R_obj_camera_par.dot(R_camera_par_camera)
        t_obj_camera = camera_center_obj

        # create final transform
        T_obj_camera = RigidTransform(R_obj_camera, t_obj_camera,
                                      from_frame=self.frame, to_frame='obj')
        return T_obj_camera.inverse()

    def sample(self, size=1):
        """ Sample random variables from the model.

        Parameters
        ----------
        size : int
            number of sample to take

        Returns
        -------
        :obj:`RigidTransform` or :obj:`list` of :obj:`RigidTransform`
            sampled object to camera poses; a single pose (not a list) when
            size is 1
        """
        samples = []
        for _ in range(size):
            # sample params
            radius = self.rad_rv.rvs(size=1)[0]
            elev = self.elev_rv.rvs(size=1)[0]
            az = self.az_rv.rvs(size=1)[0]
            roll = self.roll_rv.rvs(size=1)[0]

            # convert to camera pose
            samples.append(self.object_to_camera_pose(radius, elev, az, roll))

        # not a list if only 1 sample
        if size == 1:
            return samples[0]
        return samples

class UniformPlanarWorksurfaceRandomVariable(RandomVariable):
    """
    Uniform distribution over camera poses and intrinsics for a bounded region of a viewing sphere and planar worksurface.
    """
    def __init__(self, frame, config, num_prealloc_samples=1):
        """Initialize a uniform random variable over camera poses and intrinsics.

        Parameters
        ----------
        frame: :obj:`str`
            string name of the camera frame
        config : :obj:`autolab_core.YamlConfig`
            configuration containing parameters of random variable
        num_prealloc_samples : int
            Number of preallocated samples.

        Notes
        -----
        Required parameters of config are specified in Other Parameters

        Other Parameters
        ----------------
        min_f : float
            Minimum focal length of camera
        max_f : float
            Maximum focal length of camera
        min_cx : float
            Minimum camera optical center in x
        max_cx : float
            Maximum camera optical center in x
        min_cy : float
            Minimum camera optical center in y
        max_cy : float
            Maximum camera optical center in y
        im_height : int
            Height of camera image
        im_width : int
            Width of camera image
        min_radius : float
            Minimum radius for viewing sphere.
        max_radius : float
            Maximum radius for viewing sphere.
        min_elev : float
            Minimum elevation (angle from z-axis), in degrees, for camera position.
        max_elev : float
            Maximum elevation for camera position, in degrees.
        min_az : float
            Minimum azimuth (angle from x-axis), in degrees, for camera position.
        max_az : float
            Maximum azimuth, in degrees, for camera position.
        min_roll : float
            Minimum roll (rotation of camera about axis generated by azimuth and
            elevation), in degrees, for camera.
        max_roll : float
            Maximum roll, in degrees, for camera.
        min_x : float
            Minimum x translation of object on table
        max_x : float
            Maximum x translation of object on table
        min_y : float
            Minimum y translation of object on table
        max_y : float
            Maximum y translation of object on table
        """
        # read params
        self.frame = frame
        self.config = config
        self.num_prealloc_samples = num_prealloc_samples

        self._parse_config(config)

        # setup random variables (scipy uniform is parameterized by loc and width)

        # camera
        self.focal_rv = ss.uniform(loc=self.min_f, scale=self.max_f-self.min_f)
        self.cx_rv = ss.uniform(loc=self.min_cx, scale=self.max_cx-self.min_cx)
        self.cy_rv = ss.uniform(loc=self.min_cy, scale=self.max_cy-self.min_cy)

        # viewsphere
        self.rad_rv = ss.uniform(loc=self.min_radius, scale=self.max_radius-self.min_radius)
        self.elev_rv = ss.uniform(loc=self.min_elev, scale=self.max_elev-self.min_elev)
        self.az_rv = ss.uniform(loc=self.min_az, scale=self.max_az-self.min_az)
        self.roll_rv = ss.uniform(loc=self.min_roll, scale=self.max_roll-self.min_roll)

        # table translation
        self.tx_rv = ss.uniform(loc=self.min_x, scale=self.max_x-self.min_x)
        self.ty_rv = ss.uniform(loc=self.min_y, scale=self.max_y-self.min_y)

        RandomVariable.__init__(self, self.num_prealloc_samples)

    def _parse_config(self, config):
        """ Reads parameters from the config into class members.

        Angular bounds are converted from degrees to radians here; all other
        values are stored as given.
        """
        # camera params
        self.min_f = config['min_f']
        self.max_f = config['max_f']
        self.min_cx = config['min_cx']
        self.max_cx = config['max_cx']
        self.min_cy = config['min_cy']
        self.max_cy = config['max_cy']
        self.im_height = config['im_height']
        self.im_width = config['im_width']

        # viewsphere params
        self.min_radius = config['min_radius']
        self.max_radius = config['max_radius']
        self.min_az = np.deg2rad(config['min_az'])
        self.max_az = np.deg2rad(config['max_az'])
        self.min_elev = np.deg2rad(config['min_elev'])
        self.max_elev = np.deg2rad(config['max_elev'])
        self.min_roll = np.deg2rad(config['min_roll'])
        self.max_roll = np.deg2rad(config['max_roll'])

        # params of translation in plane
        self.min_x = config['min_x']
        self.max_x = config['max_x']
        self.min_y = config['min_y']
        self.max_y = config['max_y']

    def object_to_camera_pose(self, radius, elev, az, roll, x, y):
        """ Convert spherical coords and a planar offset to an object-camera pose.

        Parameters
        ----------
        radius : float
            distance of the camera center from the point it looks at
        elev : float
            elevation passed to `sph2cart`
        az : float
            azimuth passed to `sph2cart`
        roll : float
            rotation, in radians, applied about the camera z-axis
        x : float
            translation of the camera center along the x-axis
        y : float
            translation of the camera center along the y-axis

        Returns
        -------
        :obj:`RigidTransform`
            inverse of the transform built from `self.frame` to 'obj'
        """
        # generate camera center from spherical coords; the planar offset
        # moves the camera center but not its viewing direction
        view_offset_obj = np.array([sph2cart(radius, az, elev)]).squeeze()
        delta_t = np.array([x, y, 0])
        camera_center_obj = view_offset_obj + delta_t
        camera_z_obj = -view_offset_obj
        camera_z_obj = camera_z_obj / np.linalg.norm(camera_z_obj)

        # find the canonical camera x and y axes
        camera_x_par_obj = np.array([camera_z_obj[1], -camera_z_obj[0], 0])
        if np.linalg.norm(camera_x_par_obj) == 0:
            # z-axis is parallel to the object z-axis: pick x arbitrarily
            camera_x_par_obj = np.array([1, 0, 0])
        camera_x_par_obj = camera_x_par_obj / np.linalg.norm(camera_x_par_obj)
        camera_y_par_obj = np.cross(camera_z_obj, camera_x_par_obj)
        camera_y_par_obj = camera_y_par_obj / np.linalg.norm(camera_y_par_obj)
        if camera_y_par_obj[2] > 0:
            # flip x so that the y-axis has a non-positive z component
            camera_x_par_obj = -camera_x_par_obj
            camera_y_par_obj = np.cross(camera_z_obj, camera_x_par_obj)
            camera_y_par_obj = camera_y_par_obj / np.linalg.norm(camera_y_par_obj)

        # rotate by the roll
        R_obj_camera_par = np.c_[camera_x_par_obj, camera_y_par_obj, camera_z_obj]
        R_camera_par_camera = np.array([[np.cos(roll), -np.sin(roll), 0],
                                        [np.sin(roll), np.cos(roll), 0],
                                        [0, 0, 1]])
        R_obj_camera = R_obj_camera_par.dot(R_camera_par_camera)
        t_obj_camera = camera_center_obj

        # create final transform
        T_obj_camera = RigidTransform(R_obj_camera, t_obj_camera,
                                      from_frame=self.frame,
                                      to_frame='obj')

        return T_obj_camera.inverse()

    def camera_intrinsics(self, T_camera_obj, f, cx, cy):
        """ Generate shifted camera intrinsics to simulate cropping.

        Parameters
        ----------
        T_camera_obj : :obj:`RigidTransform`
            pose used to map the object origin into the camera frame
        f : float
            focal length, used for both fx and fy
        cx : float
            optical center in x before shifting
        cy : float
            optical center in y before shifting

        Returns
        -------
        :obj:`CameraIntrinsics`
            copy of the intrinsics whose optical center is reflected about
            (cx, cy) away from the projected object origin
        """
        # form intrinsics
        camera_intr = CameraIntrinsics(self.frame, fx=f, fy=f,
                                       cx=cx, cy=cy, skew=0.0,
                                       height=self.im_height, width=self.im_width)

        # compute new camera center by projecting object 0,0,0 into the camera
        center_obj_obj = Point(np.zeros(3), frame='obj')
        center_obj_camera = T_camera_obj * center_obj_obj
        u_center_obj = camera_intr.project(center_obj_camera)
        camera_shifted_intr = copy.deepcopy(camera_intr)
        # new center = c + (c - u) = 2c - u
        camera_shifted_intr.cx = 2 * camera_intr.cx - float(u_center_obj.x)
        camera_shifted_intr.cy = 2 * camera_intr.cy - float(u_center_obj.y)
        return camera_shifted_intr

    def sample(self, size=1):
        """ Sample random variables from the model.

        Parameters
        ----------
        size : int
            number of sample to take

        Returns
        -------
        :obj:`CameraSample` or :obj:`list` of :obj:`CameraSample`
            sampled camera poses and intrinsics; a single sample (not a list)
            when size is 1
        """
        samples = []
        for _ in range(size):
            # sample camera params
            focal = self.focal_rv.rvs(size=1)[0]
            cx = self.cx_rv.rvs(size=1)[0]
            cy = self.cy_rv.rvs(size=1)[0]

            # sample viewsphere params
            radius = self.rad_rv.rvs(size=1)[0]
            elev = self.elev_rv.rvs(size=1)[0]
            az = self.az_rv.rvs(size=1)[0]
            roll = self.roll_rv.rvs(size=1)[0]

            # sample plane translation
            tx = self.tx_rv.rvs(size=1)[0]
            ty = self.ty_rv.rvs(size=1)[0]

            # arguments are passed to the logger so that formatting is
            # skipped when debug output is disabled
            logging.debug('Sampled')

            logging.debug('focal: %.3f', focal)
            logging.debug('cx: %.3f', cx)
            logging.debug('cy: %.3f', cy)

            logging.debug('radius: %.3f', radius)
            logging.debug('elev: %.3f', elev)
            logging.debug('az: %.3f', az)
            logging.debug('roll: %.3f', roll)

            logging.debug('tx: %.3f', tx)
            logging.debug('ty: %.3f', ty)

            # convert to pose and intrinsics
            object_to_camera_pose = self.object_to_camera_pose(radius, elev, az, roll,
                                                               tx, ty)
            camera_shifted_intr = self.camera_intrinsics(object_to_camera_pose,
                                                         focal, cx, cy)
            camera_sample = CameraSample(object_to_camera_pose,
                                         camera_shifted_intr,
                                         radius, elev, az, roll, tx=tx, ty=ty,
                                         focal=focal, cx=cx, cy=cy)

            # bundle the pose, intrinsics and raw draws together
            samples.append(camera_sample)

        # not a list if only 1 sample
        if size == 1:
            return samples[0]
        return samples

class UniformPlanarWorksurfaceImageRandomVariable(RandomVariable):
    """ Random variable for sampling images from a camera """
    def __init__(self, mesh, render_modes, frame, config, stable_pose=None, scene_objs=None, num_prealloc_samples=0):
        """Initialize a random variable over rendered images of a mesh.

        Parameters
        ----------
        mesh : :obj:`Mesh3D`
            mesh of the object to render
        render_modes : :obj:`list` of :obj:`perception.RenderMode`
            render modes to use
        frame: :obj:`str`
            string name of the camera frame
        config : :obj:`autolab_core.YamlConfig`
            configuration containing parameters of random variable
        stable_pose : :obj:`StablePose`
            stable pose for the mesh to rest in
        scene_objs : :obj:`dict` mapping :obj:`str` to :obj:`SceneObject`
            objects to render statically in the scene; None adds no objects
        num_prealloc_samples : int
            Number of preallocated samples.

        Notes
        -----
        Required parameters of config are specified in Other Parameters

        Other Parameters
        ----------------
        min_f : float
            Minimum focal length of camera
        max_f : float
            Maximum focal length of camera
        min_cx : float
            Minimum camera optical center in x
        max_cx : float
            Maximum camera optical center in x
        min_cy : float
            Minimum camera optical center in y
        max_cy : float
            Maximum camera optical center in y
        im_height : int
            Height of camera image
        im_width : int
            Width of camera image
        min_radius : float
            Minimum radius for viewing sphere.
        max_radius : float
            Maximum radius for viewing sphere.
        min_elev : float
            Minimum elevation (angle from z-axis) for camera position.
        max_elev : float
            Maximum elevation for camera position.
        min_az : float
            Minimum azimuth (angle from x-axis) for camera position.
        max_az : float
            Maximum azimuth for camera position.
        min_roll : float
            Minimum roll (rotation of camera about axis generated by azimuth and
            elevation) for camera.
        max_roll : float
            Maximum roll for camera.
        min_x : float
            Minimum x translation of object on table
        max_x : float
            Maximum x translation of object on table
        min_y : float
            Minimum y translation of object on table
        max_y : float
            Maximum y translation of object on table
        """
        # read params
        self.mesh = mesh
        self.render_modes = render_modes
        self.frame = frame
        self.config = config
        self.stable_pose = stable_pose
        self.scene_objs = scene_objs
        self.num_prealloc_samples = num_prealloc_samples

        # init random variables
        self.ws_rv = UniformPlanarWorksurfaceRandomVariable(self.frame, self.config, num_prealloc_samples=self.num_prealloc_samples)

        RandomVariable.__init__(self, self.num_prealloc_samples)

    def sample(self, size=1):
        """ Sample random variables from the model.

        Parameters
        ----------
        size : int
            number of sample to take

        Returns
        -------
        :obj:`RenderSample` or :obj:`list` of :obj:`RenderSample`
            rendered images (keyed by render mode) together with the camera
            sample used to render them; a single sample (not a list) when
            size is 1
        """
        samples = []
        for _ in range(size):
            # sample camera params
            camera_sample = self.ws_rv.sample(size=1)

            # build a fresh camera since the intrinsics differ per sample
            camera = VirtualCamera(camera_sample.camera_intr)

            # scene_objs defaults to None, meaning no static objects;
            # items() works on both Python 2 and 3 dicts
            if self.scene_objs is not None:
                for name, scene_obj in self.scene_objs.items():
                    camera.add_to_scene(name, scene_obj)

            # render images, one per requested mode
            image_bundle = {}
            for render_mode in self.render_modes:
                images = camera.wrapped_images(self.mesh,
                                               [camera_sample.object_to_camera_pose],
                                               render_mode, stable_pose=self.stable_pose)
                image_bundle[render_mode] = images[0]

            # pair the renders with the camera that produced them
            samples.append(RenderSample(image_bundle, camera_sample))

        # not a list if only 1 sample
        if size == 1:
            return samples[0]
        return samples
        
